{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d7e6ab09",
   "metadata": {},
   "source": [
    "# A deep dive into robomimic datasets\n",
    "\n",
    "This notebook will provide examples on how to work with robomimic datasets through various python code examples. This notebook assumes that you have installed `robomimic` and `robosuite` (which should be on the `offline_study` branch)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a05e543",
   "metadata": {},
   "source": [
    "## Download dataset\n",
    "\n",
    "First, let's try downloading a simple dataset - we'll use the Lift (PH) dataset. Note that there are utility scripts such as `scripts/download_datasets.py` to do this for us, but for the purposes of this example, we'll use the python API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e2b90e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import h5py\n",
    "import numpy as np\n",
    "\n",
    "import robomimic\n",
    "import robomimic.utils.file_utils as FileUtils\n",
    "\n",
    "# the dataset registry can be found at robomimic/__init__.py\n",
    "from robomimic import DATASET_REGISTRY\n",
    "\n",
    "# set download folder and make it\n",
    "download_folder = \"/tmp/robomimic_ds_example\"\n",
    "os.makedirs(download_folder, exist_ok=True)\n",
    "\n",
    "# download the dataset\n",
    "task = \"lift\"\n",
    "dataset_type = \"ph\"\n",
    "hdf5_type = \"low_dim\"\n",
    "FileUtils.download_url(\n",
    "    url=DATASET_REGISTRY[task][dataset_type][hdf5_type][\"url\"], \n",
    "    download_dir=download_folder,\n",
    ")\n",
    "\n",
    "# enforce that the dataset exists\n",
    "dataset_path = os.path.join(download_folder, \"low_dim_v141.hdf5\")\n",
    "assert os.path.exists(dataset_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54bdec82",
   "metadata": {},
   "source": [
    "## Read quantities from dataset\n",
    "\n",
    "Next, let's demonstrate how to read different quantities from the dataset. There are scripts such as `scripts/get_dataset_info.py` that can help you easily understand the contents of a dataset, but in this example, we'll break down how to do this directly.\n",
    "\n",
    "First, let's take a look at the number of demonstrations in the file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a35cd8e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# open file\n",
    "f = h5py.File(dataset_path, \"r\")\n",
    "\n",
    "# each demonstration is a group under \"data\"\n",
    "demos = list(f[\"data\"].keys())\n",
    "num_demos = len(demos)\n",
    "\n",
    "print(\"hdf5 file {} has {} demonstrations\".format(dataset_path, num_demos))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bdb073a0",
   "metadata": {},
   "source": [
    "Next, let's list all of the demonstrations, along with the number of state-action pairs in each demonstration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bda3e70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# each demonstration is named \"demo_#\" where # is a number.\n",
    "# Let's put the demonstration list in increasing episode order\n",
    "inds = np.argsort([int(elem[5:]) for elem in demos])\n",
    "demos = [demos[i] for i in inds]\n",
    "\n",
    "for ep in demos:\n",
    "    num_actions = f[\"data/{}/actions\".format(ep)].shape[0]\n",
    "    print(\"{} has {} samples\".format(ep, num_actions))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff998d62",
   "metadata": {},
   "source": [
    "Now, let's dig into a single trajectory to take a look at some of the quantities in each demonstration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f7b497a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# look at first demonstration\n",
    "demo_key = demos[0]\n",
    "demo_grp = f[\"data/{}\".format(demo_key)]\n",
    "\n",
    "# Each observation is a dictionary that maps modalities to numpy arrays, and\n",
    "# each action is a numpy array. Let's print the observations and actions for the \n",
    "# first 5 timesteps of this trajectory.\n",
    "for t in range(5):\n",
    "    print(\"timestep {}\".format(t))\n",
    "    obs_t = dict()\n",
    "    # each observation modality is stored as a subgroup\n",
    "    for k in demo_grp[\"obs\"]:\n",
    "        obs_t[k] = demo_grp[\"obs/{}\".format(k)][t] # numpy array\n",
    "    act_t = demo_grp[\"actions\"][t]\n",
    "    \n",
    "    # pretty-print observation and action using json\n",
    "    obs_t_pp = { k : obs_t[k].tolist() for k in obs_t }\n",
    "    print(\"obs\")\n",
    "    print(json.dumps(obs_t_pp, indent=4))\n",
    "    print(\"action\")\n",
    "    print(act_t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "552be387",
   "metadata": {},
   "outputs": [],
   "source": [
    "# we can also grab multiple timesteps at once directly, or even the full trajectory at once\n",
    "first_ten_actions = demo_grp[\"actions\"][:10]\n",
    "print(\"shape of first ten actions {}\".format(first_ten_actions.shape))\n",
    "all_actions = demo_grp[\"actions\"][:]\n",
    "print(\"shape of all actions {}\".format(all_actions.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57976238",
   "metadata": {},
   "outputs": [],
   "source": [
    "# the trajectory also contains the next observations under \"next_obs\", \n",
    "# for convenient use in a batch (offline) RL pipeline. Let's verify\n",
    "# that \"next_obs\" and \"obs\" are offset by 1.\n",
    "for k in demo_grp[\"obs\"]:\n",
    "    # obs_{t+1} == next_obs_{t}\n",
    "    assert(np.allclose(demo_grp[\"obs\"][k][1:], demo_grp[\"next_obs\"][k][:-1]))\n",
    "print(\"success\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51ab4a38",
   "metadata": {},
   "outputs": [],
   "source": [
    "# we also have \"done\" and \"reward\" information stored in each trajectory.\n",
    "# In this case, we have sparse rewards that indicate task completion at\n",
    "# that timestep.\n",
    "dones = demo_grp[\"dones\"][:]\n",
    "rewards = demo_grp[\"rewards\"][:]\n",
    "print(\"dones\")\n",
    "print(dones)\n",
    "print(\"\")\n",
    "print(\"rewards\")\n",
    "print(rewards)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "360df27c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# each demonstration also contains metadata\n",
    "num_samples = demo_grp.attrs[\"num_samples\"] # number of samples in this trajectory\n",
    "mujoco_xml_file = demo_grp.attrs[\"model_file\"] # mujoco XML file for this demonstration\n",
    "print(mujoco_xml_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f10f98f",
   "metadata": {},
   "source": [
    "Finally, let's take a look at some global metadata present in the file. The hdf5 file stores environment metadata which is a convenient way to understand which simulation environment (task) the dataset was collected on. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b579caf",
   "metadata": {},
   "outputs": [],
   "source": [
    "env_meta = json.loads(f[\"data\"].attrs[\"env_args\"])\n",
    "# note: we could also have used the following function:\n",
    "# env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=dataset_path)\n",
    "print(\"==== Env Meta ====\")\n",
    "print(json.dumps(env_meta, indent=4))\n",
    "print(\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b395453a",
   "metadata": {},
   "source": [
    "## Visualizing demonstration trajectories\n",
    "\n",
    "Finally, let's play some of these demonstrations back in the simulation environment to easily visualize the data that was collected."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d613ab93",
   "metadata": {},
   "source": [
    "It turns out that the environment metadata stored in the hdf5 allows us to easily create a simulation environment that is consistent with the way the dataset was collected!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c98068e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import robomimic.utils.env_utils as EnvUtils\n",
    "\n",
    "# create simulation environment from environment metedata\n",
    "env = EnvUtils.create_env_from_metadata(\n",
    "    env_meta=env_meta, \n",
    "    render=False,            # no on-screen rendering\n",
    "    render_offscreen=True,   # off-screen rendering to support rendering video frames\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "595a47d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import robomimic.utils.obs_utils as ObsUtils\n",
    "\n",
    "# We normally need to make sure robomimic knows which observations are images (for the\n",
    "# data processing pipeline). This is usually inferred from your training config, but\n",
    "# since we are just playing back demonstrations, we just need to initialize robomimic\n",
    "# with a dummy spec.\n",
    "dummy_spec = dict(\n",
    "    obs=dict(\n",
    "            low_dim=[\"robot0_eef_pos\"],\n",
    "            rgb=[],\n",
    "        ),\n",
    ")\n",
    "ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs=dummy_spec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d997cf39",
   "metadata": {},
   "outputs": [],
   "source": [
    "import imageio\n",
    "\n",
    "# prepare to write playback trajectories to video\n",
    "video_path = os.path.join(download_folder, \"playback.mp4\")\n",
    "video_writer = imageio.get_writer(video_path, fps=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfae1aaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def playback_trajectory(demo_key):\n",
    "    \"\"\"\n",
    "    Simple helper function to playback the trajectory stored under the hdf5 group @demo_key and\n",
    "    write frames rendered from the simulation to the active @video_writer.\n",
    "    \"\"\"\n",
    "    \n",
    "    # robosuite datasets store the ground-truth simulator states under the \"states\" key.\n",
    "    # We will use the first one, alone with the model xml, to reset the environment to\n",
    "    # the initial configuration before playing back actions.\n",
    "    init_state = f[\"data/{}/states\".format(demo_key)][0]\n",
    "    model_xml = f[\"data/{}\".format(demo_key)].attrs[\"model_file\"]\n",
    "    initial_state_dict = dict(states=init_state, model=model_xml)\n",
    "    \n",
    "    # reset to initial state\n",
    "    env.reset_to(initial_state_dict)\n",
    "    \n",
    "    # playback actions one by one, and render frames\n",
    "    actions = f[\"data/{}/actions\".format(demo_key)][:]\n",
    "    for t in range(actions.shape[0]):\n",
    "        env.step(actions[t])\n",
    "        video_img = env.render(mode=\"rgb_array\", height=512, width=512, camera_name=\"agentview\")\n",
    "        video_writer.append_data(video_img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "926d1811",
   "metadata": {},
   "outputs": [],
   "source": [
    "# playback the first 5 demos\n",
    "for ep in demos[:5]:\n",
    "    print(\"Playing back demo key: {}\".format(ep))\n",
    "    playback_trajectory(ep)\n",
    "\n",
    "# done writing video\n",
    "video_writer.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc89c8d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# view the trajectories!\n",
    "from IPython.display import Video\n",
    "Video(video_path, embed=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
